{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j_LlXHYcmRaC"
      },
      "source": [
        "# Lookahead Optimizer on MNIST\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/lookahead_mnist.ipynb)\n",
        "\n",
        "This notebook trains a simple Convolutional Neural Network (CNN) for hand-written digit recognition (MNIST dataset) using {py:func}`optax.lookahead`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cu0kFNrnJj7"
      },
      "outputs": [],
      "source": [
        "from flax import linen as nn\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Adl_l_uZs1d"
      },
      "outputs": [],
      "source": [
        "# Hyperparameters for the run. FAST_LEARNING_RATE is the value given to\n",
        "# `optax.adam`; SYNC_PERIOD and SLOW_LEARNING_RATE are forwarded to\n",
        "# `optax.lookahead`; BATCH_SIZE is used for both the train and test loaders.\n",
        "# @markdown The learning rate for the fast optimizer:\n",
        "FAST_LEARNING_RATE = 0.002 # @param{type:\"number\"}\n",
        "# @markdown The learning rate for the slow optimizer:\n",
        "SLOW_LEARNING_RATE = 0.1 # @param{type:\"number\"}\n",
        "# @markdown Number of fast optimizer steps to take before synchronizing parameters:\n",
        "SYNC_PERIOD = 5 # @param{type:\"integer\"}\n",
        "# @markdown Number of samples in each batch:\n",
        "BATCH_SIZE = 256 # @param{type:\"integer\"}\n",
        "# @markdown Total number of epochs to train for:\n",
        "N_EPOCHS = 1 # @param{type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZZej3FcOhuRE"
      },
      "source": [
        "MNIST is a dataset of 28x28 images with 1 channel. We now load the dataset using `tensorflow_datasets`, apply min-max normalization to images, shuffle the data in the train set and create batches of size `BATCH_SIZE`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xPZ0paOehHWg"
      },
      "outputs": [],
      "source": [
        "def min_max_rgb(image, label):\n",
        "  \"\"\"Casts `image` to float32 and divides it by 255; `label` is passed through.\"\"\"\n",
        "  return tf.cast(image, tf.float32) / 255., label\n",
        "\n",
        "\n",
        "# Loads the train and test splits as (image, label) pairs, plus dataset metadata.\n",
        "(train_loader, test_loader), info = tfds.load(\n",
        "    \"mnist\", split=[\"train\", \"test\"], as_supervised=True, with_info=True\n",
        ")\n",
        "IMG_SIZE = info.features[\"image\"].shape\n",
        "NUM_CLASSES = info.features[\"label\"].num_classes\n",
        "\n",
        "# The same rescaling is applied to both splits.\n",
        "train_loader = train_loader.map(min_max_rgb)\n",
        "test_loader = test_loader.map(min_max_rgb)\n",
        "\n",
        "# Only the training split is shuffled. `drop_remainder=True` keeps every batch\n",
        "# at exactly BATCH_SIZE examples, so a final partial batch is discarded.\n",
        "train_loader_batched = train_loader.shuffle(\n",
        "    buffer_size=10_000, reshuffle_each_iteration=True\n",
        ").batch(BATCH_SIZE, drop_remainder=True)\n",
        "test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkLaC2MlbAqa"
      },
      "source": [
        "The data is ready! Next let's define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple CNN."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RppusWrcaXzX"
      },
      "outputs": [],
      "source": [
        "class CNN(nn.Module):\n",
        "  \"\"\"A simple CNN: two conv/ReLU/average-pool stages followed by two dense layers.\n",
        "\n",
        "  Attributes:\n",
        "    num_classes: number of output logits produced by the last dense layer.\n",
        "      Defaults to 10.\n",
        "  \"\"\"\n",
        "\n",
        "  num_classes: int = 10\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, x):\n",
        "    \"\"\"Maps a batch of inputs to unnormalized per-class scores.\n",
        "\n",
        "    Args:\n",
        "      x: batch of inputs whose leading axis is the batch axis (assumed to be\n",
        "        channels-last images -- TODO confirm against the caller).\n",
        "\n",
        "    Returns:\n",
        "      Array of shape `(x.shape[0], num_classes)`.\n",
        "    \"\"\"\n",
        "    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
        "    x = nn.relu(x)\n",
        "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
        "    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
        "    x = nn.relu(x)\n",
        "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
        "    # Flattens everything except the batch axis before the dense layers.\n",
        "    x = x.reshape((x.shape[0], -1))\n",
        "    x = nn.Dense(features=256)(x)\n",
        "    x = nn.relu(x)\n",
        "    x = nn.Dense(features=self.num_classes)(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DKOi55MgdPyp"
      },
      "outputs": [],
      "source": [
        "net = CNN()\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def predict(params, inputs):\n",
        "  \"\"\"Applies `net` with the given `params` to `inputs` and returns its output.\"\"\"\n",
        "  variables = {'params': params}\n",
        "  return net.apply(variables, inputs)\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def loss_accuracy(params, data):\n",
        "  \"\"\"Computes loss and accuracy over a mini-batch.\n",
        "\n",
        "  Args:\n",
        "    params: parameters of the model.\n",
        "    data: tuple of (inputs, labels), where labels are integer class indices.\n",
        "\n",
        "  Returns:\n",
        "    A pair `(loss, aux)`. `loss` is the softmax cross-entropy averaged over\n",
        "    the batch. `aux` is a dict whose \"accuracy\" entry is the fraction of\n",
        "    examples for which the argmax of the logits equals the label.\n",
        "  \"\"\"\n",
        "  inputs, labels = data\n",
        "  logits = predict(params, inputs)\n",
        "  loss = optax.softmax_cross_entropy_with_integer_labels(\n",
        "      logits=logits, labels=labels\n",
        "  ).mean()\n",
        "  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n",
        "  # The metrics dict is returned as an auxiliary output next to the scalar loss.\n",
        "  return loss, {\"accuracy\": accuracy}\n",
        "\n",
        "@jax.jit\n",
        "def update_model(state, grads):\n",
        "  \"\"\"Returns the result of `state.apply_gradients(grads=grads)`.\"\"\"\n",
        "  # NOTE(review): this helper is not called anywhere else in this notebook.\n",
        "  new_state = state.apply_gradients(grads=grads)\n",
        "  return new_state"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0eB2dhIpjTIi"
      },
      "source": [
        "Next we need to initialize CNN parameters and solver state. We also define a convenience function `dataset_stats` that we'll call once per epoch to collect the loss and accuracy of our solver over the test set. We will be using the Lookahead optimizer.\n",
        "Its wrapper keeps a pair of slow and fast parameters. To\n",
        "initialize them, we create a pair of synchronized parameters from the\n",
        "initial model parameters.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PBnbq7gui34L"
      },
      "outputs": [],
      "source": [
        "# Builds the initial model parameters from a dummy batch containing one input.\n",
        "rng = jax.random.PRNGKey(0)\n",
        "dummy_data = jnp.ones((1, *IMG_SIZE), dtype=jnp.float32)\n",
        "initial_params = net.init({\"params\": rng}, dummy_data)[\"params\"]\n",
        "\n",
        "# Lookahead wraps the fast optimizer (Adam here).\n",
        "fast_solver = optax.adam(learning_rate=FAST_LEARNING_RATE)\n",
        "solver = optax.lookahead(\n",
        "    fast_solver, sync_period=SYNC_PERIOD, slow_step_size=SLOW_LEARNING_RATE\n",
        ")\n",
        "\n",
        "# Initializes the lookahead optimizer with the initial model parameters.\n",
        "params = optax.LookaheadParams.init_synced(initial_params)\n",
        "solver_state = solver.init(params)\n",
        "\n",
        "def dataset_stats(params, data_loader):\n",
        "  \"\"\"Computes loss and accuracy over the dataset `data_loader`.\n",
        "\n",
        "  Per-batch metrics from `loss_accuracy` are combined with a mean weighted by\n",
        "  batch size, so a loader whose batches differ in size is still averaged per\n",
        "  example. With equal-size batches this is the plain mean of the per-batch\n",
        "  values.\n",
        "\n",
        "  Args:\n",
        "    params: model parameters forwarded to `loss_accuracy`.\n",
        "    data_loader: dataset exposing `as_numpy_iterator()` and yielding\n",
        "      `(inputs, labels)` batches.\n",
        "\n",
        "  Returns:\n",
        "    Dict with keys \"loss\" and \"accuracy\". Both values are NaN when the loader\n",
        "    yields no batches.\n",
        "  \"\"\"\n",
        "  batch_sizes = []\n",
        "  all_loss = []\n",
        "  all_accuracy = []\n",
        "  for batch in data_loader.as_numpy_iterator():\n",
        "    batch_loss, batch_aux = loss_accuracy(params, batch)\n",
        "    # The number of labels is the number of examples in this batch.\n",
        "    batch_sizes.append(len(batch[1]))\n",
        "    all_loss.append(batch_loss)\n",
        "    all_accuracy.append(batch_aux[\"accuracy\"])\n",
        "  if not batch_sizes:\n",
        "    # Nothing to average; report NaN explicitly rather than dividing by zero.\n",
        "    return {\"loss\": np.nan, \"accuracy\": np.nan}\n",
        "  return {\n",
        "      \"loss\": np.average(all_loss, weights=batch_sizes),\n",
        "      \"accuracy\": np.average(all_accuracy, weights=batch_sizes),\n",
        "  }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4H6GWNJf0XTY"
      },
      "source": [
        "Finally, we do the actual training. The next cell trains the model for `N_EPOCHS` epochs. Within each epoch we iterate over the batched loader `train_loader_batched`, and once per epoch we also compute the test set accuracy and loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DeQr0urBjoDj"
      },
      "outputs": [],
      "source": [
        "# Per-epoch means of the training metrics, appended to by the training loop.\n",
        "train_losses = []\n",
        "train_accuracy = []\n",
        "\n",
        "# Computes test set loss and accuracy at initialization, on the slow parameters.\n",
        "test_stats = dataset_stats(params.slow, test_loader_batched)\n",
        "test_losses = [test_stats[\"loss\"]]\n",
        "test_accuracy = [test_stats[\"accuracy\"]]\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def train_step(params, solver_state, batch):\n",
        "  \"\"\"Performs one optimizer update on a single batch.\n",
        "\n",
        "  Args:\n",
        "    params: lookahead parameter pair; gradients are taken at `params.fast`.\n",
        "    solver_state: current state of `solver`.\n",
        "    batch: mini-batch forwarded to `loss_accuracy`.\n",
        "\n",
        "  Returns:\n",
        "    Tuple `(params, solver_state, loss, aux)` with the updated parameters and\n",
        "    solver state, plus the loss and auxiliary metrics computed before the update.\n",
        "  \"\"\"\n",
        "  grad_fn = jax.value_and_grad(loss_accuracy, has_aux=True)\n",
        "  (loss, aux), grad = grad_fn(params.fast, batch)\n",
        "  updates, new_solver_state = solver.update(grad, solver_state, params)\n",
        "  new_params = optax.apply_updates(params, updates)\n",
        "  return new_params, new_solver_state, loss, aux\n",
        "\n",
        "\n",
        "for epoch in range(N_EPOCHS):\n",
        "  # Per-step metrics for this epoch; their means are recorded once it ends.\n",
        "  train_losses_epoch = []\n",
        "  train_accuracy_epoch = []\n",
        "\n",
        "  batches = train_loader_batched.as_numpy_iterator()\n",
        "  for step, train_batch in enumerate(batches):\n",
        "    params, solver_state, train_loss, train_aux = train_step(\n",
        "        params, solver_state, train_batch\n",
        "    )\n",
        "    train_losses_epoch.append(train_loss)\n",
        "    train_accuracy_epoch.append(train_aux[\"accuracy\"])\n",
        "    # Logs progress every 20 steps, starting with step 0.\n",
        "    if not step % 20:\n",
        "      print(\n",
        "          f\"step {step}, train loss: {train_loss:.2e}, train accuracy:\"\n",
        "          f\" {train_aux['accuracy']:.2f}\"\n",
        "      )\n",
        "\n",
        "  # Validation is done on the slow lookahead parameters.\n",
        "  test_stats = dataset_stats(params.slow, test_loader_batched)\n",
        "  test_losses.append(test_stats[\"loss\"])\n",
        "  test_accuracy.append(test_stats[\"accuracy\"])\n",
        "  train_losses.append(np.mean(train_losses_epoch))\n",
        "  train_accuracy.append(np.mean(train_accuracy_epoch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yyS1oRZBtytP"
      },
      "outputs": [],
      "source": [
        "# Test accuracy measured before training vs. after the last epoch.\n",
        "f\"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}\""
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "execution": {
      "timeout": -1
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
